{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Solve MoutainCar-v0 using a Fixed Deterministic Policy\n",
    "\n",
    "This notebook solves 'MountainCar-v0' using a fixed deterministic policy.\n",
    "\n",
    "### Policy\n",
    "\n",
    "Push right when\n",
    "\n",
    "$\\min \\left(-0.09 \\left(\\text{position} + 0.25 \\right) ^ 2 + 0.03, 0.3 \\left(\\text{position} + 0.9 \\right) ^ 4 - 0.008 \\right) \\le \\text{velocity} \\le -0.07 \\left(\\text{position} + 0.38 \\right) ^ 2 + 0.07$,\n",
    "\n",
    "otherwise push left.\n",
    "\n",
    "\n",
    "This solution has been published in the following book:\n",
    "\n",
    "    @book{xiao2019,\n",
    "     title = {Reinforcement Learning: Theory and {Python} Implementation},\n",
    "     author = {Zhiqing Xiao}\n",
    "     year = 2019,\n",
    "     publisher = {China Machine Press},\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import itertools\n",
    "import numpy as np\n",
    "import gym\n",
    "np.random.seed(0)\n",
    "env = gym.make('MountainCar-v0')\n",
    "env.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Agent:\n",
    "    def decide(self, observation):\n",
    "        position, velocity = observation\n",
    "        lb = min(-0.09 * (position + 0.25) ** 2 + 0.03,\n",
    "                0.3 * (position + 0.9) ** 4 - 0.008)\n",
    "        ub = -0.07 * (position + 0.38) ** 2 + 0.07\n",
    "        if lb < velocity < ub:\n",
    "            action = 2 # push right\n",
    "        else:\n",
    "            action = 0 # push left\n",
    "        return action\n",
    "\n",
    "agent = Agent()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_once(env, agent, render=False, verbose=False):\n",
    "    observation = env.reset()\n",
    "    episode_reward = 0.\n",
    "    for step in itertools.count():\n",
    "        if render:\n",
    "            env.render()\n",
    "        action = agent.decide(observation)\n",
    "        observation, reward, done, _ = env.step(action)\n",
    "        episode_reward += reward\n",
    "        if done:\n",
    "            break\n",
    "    if verbose:\n",
    "        print('get {} rewards in {} steps'.format(\n",
    "                episode_reward, step + 1))\n",
    "    return episode_reward"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test 100 episode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average episode rewards = -102.61\n"
     ]
    }
   ],
   "source": [
    "episode_rewards = [play_once(env, agent) for _ in range(100)]\n",
    "print('average episode rewards = {}'.format(np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test 100000 episodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average episode rewards = -105.75467\n"
     ]
    }
   ],
   "source": [
    "episode_rewards = [play_once(env, agent) for _ in range(100000)]\n",
    "print('average episode rewards = {}'.format(np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
